import json
import os.path
import pickle
import numpy as np
from typing import Iterable, Tuple, Callable, Optional, Dict

from Utils import logger


def pkl_file_reader_gen(file_path: str) -> Iterable:
    """
    Lazily yield every object pickled back-to-back in a file.

    Paths containing ``.bz2`` are opened through ``bz2.BZ2File``; any other
    path is opened as a plain binary file. Iteration stops at end of file.
    An ``UnpicklingError`` is logged and reading continues with the next
    ``load`` call on the same stream.

    :param file_path: path of the (optionally bz2-compressed) pickle stream
    :return: generator over the unpickled objects, in file order
    """
    import bz2
    import _pickle as pkl

    is_bz2 = '.bz2' in file_path
    opener = bz2.BZ2File if is_bz2 else open

    with opener(file_path, "rb") as file:
        while True:
            try:
                # yield stays inside the try so the handlers cover it exactly as before
                yield pkl.load(file)
            except EOFError:
                return
            except pickle.UnpicklingError as e:
                if is_bz2:
                    message = f'Failed unpickle for BZ2 on file: {file_path}\n, {str(e)}'
                else:
                    message = f'Failed unpickle for pkl on file: {file_path}\n, {str(e)}'
                logger().error('Utils.pkl_file_reader_gen', e, message)


def get_feature_name_from_stat_name(stat_name) -> str:
    """
    Map a statistic name to the single-layer feature it contains.

    Matching is by substring and the first hit in the tuple wins, so a
    feature whose name contains another feature's name must be listed
    first (``geometric_mean`` before ``mean``, ``geometric_std`` before
    ``std``, ``covariance`` before ``variance``).

    :param stat_name: statistic name to inspect
    :return: the matching feature name, or ``'no_feature'`` when none matches
    """
    # 'covariance' must precede 'variance': 'variance' is a substring of it and
    # would otherwise shadow it, making the 'covariance' entry unreachable.
    all_single_layer_feature = ('geometric_mean', 'harmonic_mean', 'geometric_std', 'mean', 'covariance',
                                'variance', 'median', 'std', 'max', 'min', 'skewness', 'kurtosis',
                                'anderson_norm', 'anderson_expon', 'anderson_logistic', 'anderson_gumbel',
                                'q-th_percentile', 'L1_norm', 'L2_norm')
    for feature_name in all_single_layer_feature:
        if feature_name in stat_name:
            return feature_name
    return 'no_feature'


def flatten_problematic_stats(stat):
    """
    Flatten one nesting level of an iterable of statistics.

    List items are spliced in as-is, ``numpy`` arrays are spliced in via
    ``tolist()``, and every other item is appended unchanged.

    :param stat: iterable whose items may be lists, ndarrays or other values
    :return: a new list with list/ndarray items expanded one level
    """
    flat = []
    for item in stat:
        if isinstance(item, np.ndarray):
            flat.extend(item.tolist())
        elif isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


def results_file_gen(file_path: str) -> Iterable[Tuple]:
    """
    Yield ``(model_name, results)`` pairs from a file of alternating JSON lines.

    The first line of each pair is a JSON string; ``'--'`` and ``'model'``
    are removed from it and the remainder is stripped to give the model
    name. The second line is decoded as JSON and yielded as the results.
    A trailing name line with no results line after it is dropped.

    :param file_path: path of the text file to read
    :return: generator of ``(model_name, results)`` tuples
    """
    with open(file_path, 'r') as file:
        for name_line in file:
            model_name = json.loads(name_line).replace('--', '').replace('model', '').strip()
            results_line = next(file, None)
            if results_line is None:
                # dangling name line at end of file: nothing to pair it with
                return
            yield model_name, json.loads(results_line)


def get_model_id_from_file_name(file_name: str) -> str:
    """
    Return the last run of digits in a file's base name, ignoring extensions.

    Only the part of the base name before the first ``.`` is searched.

    :param file_name: file name or path
    :return: the last digit sequence found, as a string
    :raises IndexError: if that part of the name contains no digits
    """
    import re
    stem, _, _ = os.path.basename(file_name).partition('.')
    digit_runs = re.findall(r"\d+", stem)
    return digit_runs[-1]


def load_save_results(save_path: Optional[str] = None, force_func: bool = False):
    """
    Decorator factory that caches a function's return value in a pickle file.

    When ``save_path`` exists and ``force_func`` is false, the decorated
    function is not called; the pickled result is loaded and returned
    instead. Otherwise the function runs, its result is pickled to
    ``save_path`` (creating the parent directory if needed) and returned.
    The cache ignores the call arguments: one file holds one result.

    :param save_path: pickle file used as the cache; ``None`` disables caching
    :param force_func: if true, always run the function and overwrite the file
    :return: a decorator to apply to the function whose result is cached
    """
    from functools import wraps

    def inner(func: Callable):
        @wraps(func)  # keep the decorated function's name and docstring
        def load_save_results_inner(*args, **kwargs):
            if save_path is None:
                return func(*args, **kwargs)

            if os.path.exists(save_path) and not force_func:
                from Utils import logger
                logger().log(f'Results file: {save_path} exists skipping function: {func.__name__}')
                # NOTE: pickle.load can execute arbitrary code; only point
                # save_path at files this process (or a trusted one) wrote.
                with open(save_path, 'rb') as file:
                    return pickle.load(file)

            func_results = func(*args, **kwargs)
            # A bare file name has no parent directory; slicing on rfind('/')
            # would then create a bogus directory named after the file.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(save_path, 'wb') as file:
                pickle.dump(func_results, file)
            return func_results

        return load_save_results_inner

    return inner


def function_start_save_params(local_params: Dict, extra_data: Optional[Dict], config, save_path: str):
    """
    Run at function start to save all model arguments to a JSON file.

    :param local_params: argument mapping to save; copied, not modified
    :param extra_data: optional extra entries merged over ``local_params``
    :param config: optional object whose ``to_json()`` result is stored under ``'config'``
    :param save_path: output path; ``.json`` is appended when ``'json'`` does not occur in it
    :return: None
    """
    payload = local_params.copy()
    if extra_data is not None:
        payload.update(extra_data)
    if config is not None:
        payload['config'] = config.to_json()

    # NOTE(review): this is a substring test on the whole path, so a directory
    # or file name that merely contains 'json' also suppresses the extension.
    if 'json' in save_path:
        file_name = save_path
    else:
        file_name = f'{save_path}.json'

    logger().log('function_start_save_params', str(payload), f'\nSaving to: {file_name}')
    with open(file_name, 'w') as out:
        json.dump(payload, out)
